%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Codes for Example 3.1 
% ETKF implemented on a 3-dim SDE
% Created by John Harlim 
% Last edited: March 16, 2018  
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Start from an empty workspace with no open figures.
clear all
close all
tic   % start the wall-clock timer; the matching toc follows the diagnostics

% Declare the globals before loading the MAT-files.
% NOTE(review): presumably these are model parameters stored in the MAT-files
% and read by triad() -- confirm against triad.m.
global gamma A1 omega beta
% NOTE(review): x, DT and sigma are used below but never assigned in this
% script, so they are expected to come from these two MAT-files -- confirm.
load('triad')
load('triadstats')

n    = 3;       % dimension of the model state
TCYC = 2000;    % number of assimilation cycles
m    = 3;       % number of observed components (set m=1 to observe only the first component)

% Linear observation operator: identity on the first m state components,
% zeros elsewhere.
H = eye(m,n);

% Observation-noise variance: a quarter of the sample variance of the first
% row of x.  R is a scalar, so it acts as R*eye(m) in the formulas below.
R = 0.25*var(x(1,:));

% Synthetic observations: first TCYC columns of x mapped through H, plus
% Gaussian noise with variance R.
y = H*x(:,1:TCYC) + sqrtm(R)*randn(m,TCYC);

% Ensemble storage.  xb holds the ensemble entering each analysis step
% (one extra slot for the forecast made after the last cycle); xa holds the
% ensemble produced by each analysis step.
En = 10;                        % ensemble size
xb = zeros(n,En,TCYC+1);
xa = zeros(n,En,TCYC);

% Initial ensemble: small Gaussian perturbations about zero.
xb(:,:,1) = 0.1*randn(n,En);


for k = 1:TCYC
    % ---------------- analysis (update) step ----------------
    % Split the current ensemble into its mean and the perturbations
    % about that mean.
    xbbar = mean(xb(:,:,k),2);
    X = bsxfun(@minus, xb(:,:,k), xbbar);

    % Innovation of the mean and perturbations mapped to observation space.
    d = y(:,k) - H*xbbar;
    Y = H*X;

    % En-by-En matrix (En-1)*I + Y'*R^{-1}*Y and its SVD.  Both the gain
    % and the transform below are built from the factors U and S.
    JJ = (En-1)*eye(En) + Y'*(R\Y);
    [U,S] = svd(JJ);

    % Gain applied to the innovation: X * pinv(JJ) * Y' * R^{-1}.
    Kc = X*U*pinv(S)*U'*Y'/R;

    % Transform T = sqrt(En-1) * U * S^{-1/2} * U', applied to the
    % perturbations to obtain the updated perturbations.
    invSqrtS = diag(1./sqrt(diag(S)));
    T = sqrt(En-1)*U*invSqrtS*U';
    Xplus = X*T;

    % Updated mean, then the updated ensemble = mean + perturbations.
    xabar = xbbar + Kc*d;
    xa(:,:,k) = repmat(xabar,1,En) + Xplus;

    % ---------------- forecast step ----------------
    % One explicit step of size DT using triad() as the drift, followed by
    % additive noise scaled by sigma*sqrt(DT) on components 2 and 3 only.
    xf = xa(:,:,k) + DT*triad(xa(:,:,k));
    xf(2:3,:) = xf(2:3,:) + sigma*sqrt(DT)*randn(2,En);
    xb(:,:,k+1) = xf;
end
grey = [0.4, 0.4, 0.4];   % RGB colour used for the analysis-mean curve in the figure

% Ensemble mean of the analysis at every cycle (n-by-TCYC).
meanxa = squeeze(mean(xa,2));

% Root-mean-square difference between the analysis mean and x, averaged
% over all components and cycles.  Left unsuppressed so the value is printed.
rmsa = (meanxa - x(:,1:TCYC)).^2;
rms = sqrt(mean(mean(rmsa)))

% Analysis perturbations arranged as n-by-TCYC-by-En: move the member index
% to the third dimension and subtract the ensemble mean from every member.
Xa = bsxfun(@minus, permute(xa,[1 3 2]), meanxa);

% Sample variance of each component at each cycle (divisor En-1), and the
% resulting average ensemble spread.  Left unsuppressed so it is printed.
Padiag = sum(Xa.^2,3)/(En-1);
spread = sqrt(mean(mean(Padiag)))

toc

% Plot the analysis mean (thick grey) against x (dashed black), one panel
% per state component.
figure(1)

% Time axis with exactly TCYC points.  It is built by scaling integer
% indices rather than with DT:DT:DT*TCYC, because a colon range with a
% fractional step can lose its last element to floating-point rounding,
% which would make plot() fail with a vector-length mismatch.
tt = DT*(1:TCYC);

% y-axis label for each of the three panels.
labels = {'u','v','w'};

for j=1:3
    subplot(3,1,j)
    hold on
    plot(tt,meanxa(j,:),'color',grey,'linewidth',2)
    plot(tt,x(j,1:TCYC),'k--')
    hold off
    ylabel(labels{j})
    if (j==3)
        % Only the bottom panel carries the shared time label.
        xlabel('time')
    end
end

%print -depsc -r100 etkf.eps
